// file: arch/hardware/apic.c
// author: jiangxinpeng
// time: 2023.1.23
// copyright: (C) 2020-2050 by jiangxinpeng, all rights reserved.

#include <arch/ioapic.h>
#include <arch/acpi.h>
#include <arch/interrupt.h>
#include <arch/gate.h>
#include <os/hardirq.h>
#include <os/debug.h>
#include <os/memcache.h>
#include <arch/memio.h>
#include <arch/interrupt.h>
#include <lib/type.h>
#include <lib/stdlib.h>
#include <lib/errno.h>
#include <lib/stddef.h>
#include <lib/string.h>

// IOAPIC descriptors, filled from the MADT by IoAPICScanAll().
// NOTE(review): fixed capacity of 2 entries; discovery must not write past it.
ioapic_t ioapic[2];

// Not written by any code visible in this file; presumably maintained elsewhere — confirm.
int ioapic_present = 0;

irq2pin_map_t irq2pin_map[IRQ_NUM_MAX];   // irq -> {irq, pin, vector} table, indexed by irq number
static void *irqfunc[IOAPIC_PRT_NUM_MAX]; // generated per-irq entry stubs, indexed by irq number (so this size bounds usable irqs)
static uint8_t ioapic_nums = 0;           // number of valid entries in ioapic[]
static acpi_madt_t *madt = NULL;          // MADT pointer cached from AcpiGetMadtBase()
static uint8_t irq_alloc = 0;             // next irq2pin_map slot used by Irq2PinMapBind()

// list of discovered IOAPICs (links ioapic[i].list), searched by IoAPICFind()
LIST_HEAD(ioapic_list_head);

// forward declaration: per-chip initialisation, defined after IoAPICInit()
static int IoAPICInitEvery(ioapic_t *ioapic, uint32_t iobase);

/*
 * Interrupt-controller "ack" hook.
 * A thin adapter that hands the irq number straight to IoAPICSendEOF().
 */
static void IoAPICEOF(irqno_t irq)
{
    const irqno_t line = irq;

    IoAPICSendEOF(line);
}

/*
 * Interrupt-controller "install" hook: route irq through its IOAPIC and
 * register func as the handler with the generic IRQ layer.
 *
 * The owning IOAPIC is looked up by id = irq / IOAPIC_PRT_NUM_MAX.
 * If no IOAPIC with that id was discovered, nothing is registered.
 */
static void IoAPICRegister(irqno_t irq, irq_handler_t func)
{
    ioapic_t *chip = IoAPICFind(irq / IOAPIC_PRT_NUM_MAX);

    // IoAPICIrqRegister dereferences the chip; never hand it NULL
    if (!chip)
    {
        KPrint("[IOAPIC] no ioapic owns irq %d\n", (uint32_t)irq);
        return;
    }

    IoAPICIrqRegister(chip, irq);
    IrqRegisterInterrupt(irq, (interrupt_handler_t)func);
}

/*
 * Interrupt-controller "uninstall" hook: undo IoAPICRegister() for irq.
 *
 * The owning IOAPIC is looked up by id = irq / IOAPIC_PRT_NUM_MAX.
 * If no IOAPIC with that id was discovered, nothing is unregistered.
 */
static void IoAPICUnRegister(irqno_t irq)
{
    ioapic_t *chip = IoAPICFind(irq / IOAPIC_PRT_NUM_MAX);

    // IoAPICIrqUnReigster dereferences the chip; never hand it NULL
    if (!chip)
    {
        KPrint("[IOAPIC] no ioapic owns irq %d\n", (uint32_t)irq);
        return;
    }

    IoAPICIrqUnReigster(chip, irq);
    IrqUnRegisterInterrupt(irq);
}

/*
 * Interrupt-controller "disable" hook: mask the redirection entry for irq.
 *
 * The IOAPIC stops delivering an interrupt whose redirection entry has the
 * mask bit set. The bit is set directly with a read-modify-write, so all
 * other fields of the entry are preserved. Does nothing if no IOAPIC with
 * id irq / IOAPIC_PRT_NUM_MAX exists.
 */
static void IoAPICDisable(irqno_t irq)
{
    uint8_t idx = irq / IOAPIC_PRT_NUM_MAX;
    uint8_t prt_idx = irq % IOAPIC_PRT_NUM_MAX;
    ioapic_t *chip = IoAPICFind(idx);

    if (!chip)
        return;

    uint64_t prt = IoAPICReadPRTReg(chip, prt_idx);
    prt |= (uint64_t)1 << IOAPIC_PRT_INTMASK_SHIFT; // mask bit set = masked
    IoAPICWritePRTReg(chip, prt_idx, prt);
}

/*
 * Interrupt-controller "enable" hook: unmask the redirection entry for irq.
 *
 * The mask bit is cleared directly with a read-modify-write, so all other
 * fields of the entry are preserved. Does nothing if no IOAPIC with
 * id irq / IOAPIC_PRT_NUM_MAX exists.
 */
static void IoAPICEnable(irqno_t irq)
{
    uint8_t idx = irq / IOAPIC_PRT_NUM_MAX;
    uint8_t prt_idx = irq % IOAPIC_PRT_NUM_MAX;
    ioapic_t *chip = IoAPICFind(idx);

    if (!chip)
        return;

    uint64_t prt = IoAPICReadPRTReg(chip, prt_idx);
    prt &= ~((uint64_t)1 << IOAPIC_PRT_INTMASK_SHIFT); // mask bit clear = deliverable
    IoAPICWritePRTReg(chip, prt_idx, prt);
}

static void IoAPICScanAll()
{
    madt = (acpi_madt_t *)AcpiGetMadtBase();

    uint32_t len = madt->len;
    uint8_t *p = (address_t)madt + sizeof(acpi_madt_t);
    uint8_t *end = (address_t)madt + len;

    while (p < end)
    {
        uint8_t type = p[0];
        uint32_t len = p[1];
        uint8_t *data = (uint8_t *)(p + 2);

        switch (type) // switch entry type
        {
        case MADT_PRCESSOR_IOAPIC: // found
        {
            madt_ioapic_entry_t *ioapic_entry = (madt_ioapic_entry_t *)data;

            ioapic[ioapic_nums].ioapic_id = ioapic_entry->ioapic_id;
            ioapic[ioapic_nums].iobase = ioapic_entry->ioapic_base;
            ioapic[ioapic_nums].gsi_base = ioapic_entry->gsi_base;
            list_add_tail(&ioapic[ioapic_nums].list, &ioapic_list_head);
            KPrint("[IOAPIC] found ioapic id %d iobase %x gsi base %x\n", (uint32_t)ioapic->ioapic_id, ioapic->iobase, ioapic->gsi_base);
            ioapic_nums++;
        }
        break;
        default:
            break;
        }
        p += len; // next entry
    }

    KPrint("[IOAPIC] all found IOAPIC num %d\n", ioapic_nums);
}

static void IoAPICInitIrqSource()
{
    int i;
    for (i = 0; i < 16; i++)
    {
        irq2pin_map[i].irq = i;
        irq2pin_map[i].pin = i;
    }
}

static void IoAPICGetIrqSource()
{
    IoAPICInitIrqSource(); // init irq source map

    madt = (acpi_madt_t *)AcpiGetMadtBase();

    uint32_t len = madt->len;
    uint8_t *p = (address_t)madt + sizeof(acpi_madt_t);
    uint8_t *end = (address_t)madt + len;
    uint8_t ioapic_ints = 0;

    while (p < end)
    {
        uint8_t type = p[0];
        uint32_t len = p[1];
        uint8_t *data = (uint8_t *)(p + 2);

        switch (type) // switch entry type
        {
        case MADT_LOCAL_INT_ASSERT: // ioapic interrupt override
        {
            madt_ioapic_int_assert_t *ioapic_int = (madt_ioapic_int_assert_t *)data;

            irq2pin_map[ioapic_int->irq].pin = ioapic_int->gsi;

            KPrint("[ioapic] bus %d irq %d gsi %d\n", (uint32_t)ioapic_int->bus, (uint32_t)ioapic_int->irq, ioapic_int->gsi);
            ioapic_ints++;
        }
        break;
        case MADT_LOCAL_NMI_ASSERT: // ioapic NMI override
        {
            madt_ioapic_nmi_assert_t *ioapic_nmi = (madt_ioapic_nmi_assert_t *)data;
            KPrint("[ioapic] nmi source %d gsi %d\n", (uint32_t)ioapic_nmi->nmi, (uint32_t)ioapic_nmi->gsi);
        }
        break;
        default:
            break;
        }
        p += len; // next entry
    }

    KPrint("[IOAPIC] all found IOAPIC INT source num %d\n", ioapic_ints);
}

/*
 * Configure ISA device interrupts on the first IOAPIC: every entry is
 * written as masked, ExtINT delivery, physical destination 0,
 * level-triggered.
 *
 * Only entries the chip actually implements (index < ioapic[0].prt_num)
 * are written. If no IOAPIC was initialised, prt_num is 0 and nothing
 * is touched.
 * NOTE(review): irq2pin_map is not consulted; entry i is programmed for irq i.
 */
static void IoAPICIrqConfig()
{
    int i;
    int limit = IRQ_NUM_MAX;

    // redirection entries past prt_num do not exist on the chip
    if (limit > ioapic[0].prt_num)
        limit = ioapic[0].prt_num;

    for (i = 0; i < limit; i++)
    {
        IoAPICSetPRT(&ioapic[0], i, 0, 0, IOAPIC_PRT_DELIVERY_MODE_EXINT, IOAPIC_PRT_DESTMODE_PHYSICAL, IOAPIC_PRT_TRIGMODE_LEVEL, 0, IOAPIC_PRT_INT_MASK);
    }
}

int IoAPICInit()
{
    // disable 8259a
    PicDisable();

    // scan all ioapic on system
    IoAPICScanAll();

    ioapic_t *ioapic = NULL;

    // activity all ioapic on system
    list_traversal_all_owner_to_next(ioapic, &ioapic_list_head, list)
    {
        IoAPICInitEvery(ioapic, ioapic->iobase);
    }

    // get ISA device irq with gsi map
    IoAPICGetIrqSource();

    // config ISA device irq
    IoAPICIrqConfig();

    // rebind interrupt controller interface
    /*interrupt_controller.disable = IoAPICDisable;
    interrupt_controller.enable = IoAPICEnable;
    interrupt_controller.init = IoAPICInit;
    interrupt_controller.ack = IoAPICEOF;
    interrupt_controller.install = IoAPICRegister;
    interrupt_controller.uninstall = IoAPICUnRegister;*/
    KPrint("[IOAPIC] ioapic hardware init ok\n");
}

/*
 * Bring one IOAPIC into a known state: map its registers, read how many
 * redirection entries it implements (fills ioapic->prt_num), program every
 * entry, and finally disable them all via IoAPICDisableAll().
 *
 * @ioapic  descriptor to initialise
 * @iobase  physical base address of the register window
 * Returns 0. None of the steps report failure; IoAPICMap() panics if it
 * cannot allocate a mapping.
 */
static int IoAPICInitEvery(ioapic_t *ioapic, uint32_t iobase)
{
    // init map
    IoAPICMap(ioapic, iobase);
    // get number of ioapic prt max
    IoAPICGetPRTMax(ioapic);
    // init prt to disable interrupt,fixed,edge,high level trigger
    IoAPICPRTInit(ioapic);
    // interrupt disable all
    IoAPICDisableAll(ioapic);
    return 0;
}

/*
 * Read the IOAPIC version register and store in ioapic->prt_num how many
 * redirection entries the chip implements. The register's entry field
 * holds the highest entry index, hence the +1.
 */
void IoAPICGetPRTMax(ioapic_t *ioapic)
{
    const uint32_t version_reg = IoAPICReadReg(ioapic, IOAPIC_REG_VER_IDX);
    const uint32_t highest_index = IOAPIC_VER_PRTNUM(version_reg);

    ioapic->prt_num = highest_index + 1;
    KPrint("[ioapic] version 0x%x max prt number is %d\n", version_reg & 0xff, ioapic->prt_num);
}

/*
 * Make the IOAPIC register window at physical address iobase accessible
 * and store the address to use in ioapic->iobase.
 *
 * With paging enabled, IOAPIC_MAP_SIZE bytes of kernel virtual space are
 * taken from KMemAlloc and remapped onto iobase; allocation failure panics.
 * With paging disabled, the physical address itself is stored.
 */
void IoAPICMap(ioapic_t *ioapic, uint32_t iobase)
{
    if (!page_enable)
    {
        // no translation: the registers are reached at their physical address
        ioapic->iobase = (uint8_t *)(address_t)iobase;
        return;
    }

    void *vaddr = (void *)KMemAlloc(IOAPIC_MAP_SIZE);
    if (!vaddr)
        Panic("[ioapic] mem map alloc err\n");
    HalMemIoReMap((intptr_t)vaddr, iobase, IOAPIC_MAP_SIZE);
    ioapic->iobase = (uint8_t *)vaddr;
}

/*
 * Remove the IOAPIC_MAP_SIZE-byte I/O mapping that starts at ioapic->iobase.
 * NOTE(review): the virtual range obtained with KMemAlloc in IoAPICMap() is
 * not released here.
 */
void IoAPICUnMap(ioapic_t *ioapic)
{
    intptr_t mapped = (intptr_t)ioapic->iobase;

    HalMemIoUnMap(mapped, IOAPIC_MAP_SIZE);
}

/*
 * Write data to IOAPIC register idx through the indirect window: store the
 * register number in the index register, then the value in the data register.
 *
 * Both stores are memory-mapped I/O, so they go through volatile pointers.
 * Otherwise the compiler may merge, drop, or reorder them.
 */
void IoAPICWriteReg(ioapic_t *ioapic, uint32_t idx, uint32_t data)
{
    volatile uint32_t *sel = (volatile uint32_t *)(ioapic->iobase + IOAPIC_REG_IDX);
    volatile uint32_t *win = (volatile uint32_t *)(ioapic->iobase + IOAPIC_REG_DATA);

    *sel = idx;
    *win = data;
}

/*
 * Read IOAPIC register idx through the indirect window: select it via the
 * index register, then load the data register.
 *
 * Volatile accesses keep the compiler from caching or reordering the MMIO
 * select/read pair.
 */
uint32_t IoAPICReadReg(ioapic_t *ioapic, uint32_t idx)
{
    volatile uint32_t *sel = (volatile uint32_t *)(ioapic->iobase + IOAPIC_REG_IDX);
    volatile uint32_t *win = (volatile uint32_t *)(ioapic->iobase + IOAPIC_REG_DATA);

    *sel = idx;
    return *win;
}

/*
 * Read the 64-bit redirection entry prt_idx. The entry occupies two 32-bit
 * registers: low word at IOAPIC_REG_PRT_IDX + 2 * prt_idx, high word right
 * after it. Indices >= IOAPIC_PRT_NUM_MAX read as 0.
 */
uint64_t IoAPICReadPRTReg(ioapic_t *ioapic, uint8_t prt_idx)
{
    uint32_t reg = IOAPIC_REG_PRT_IDX + prt_idx * 2;
    uint64_t entry;

    if (prt_idx < IOAPIC_PRT_NUM_MAX)
    {
        entry = IoAPICReadReg(ioapic, reg + 1); /* high word is read first */
        entry <<= 32;
        entry |= IoAPICReadReg(ioapic, reg);
        return entry;
    }

    return 0;
}

/*
 * Write the 64-bit redirection entry prt_idx: bits 0-31 go to register
 * IOAPIC_REG_PRT_IDX + 2 * prt_idx, bits 32-63 to the register after it
 * (the same layout IoAPICReadPRTReg reads). Indices >= IOAPIC_PRT_NUM_MAX
 * are ignored.
 */
void IoAPICWritePRTReg(ioapic_t *ioapic, uint8_t prt_idx, uint64_t data)
{
    if (prt_idx >= IOAPIC_PRT_NUM_MAX)
        return;

    uint32_t prt_low = (uint32_t)data;
    uint32_t prt_high = (uint32_t)(data >> 32);

    // high word (destination) first, so the destination is already in place
    // when the low word (vector/mask) lands and may unmask the entry
    IoAPICWriteReg(ioapic, IOAPIC_REG_PRT_IDX + prt_idx * 2 + 1, prt_high);
    IoAPICWriteReg(ioapic, IOAPIC_REG_PRT_IDX + prt_idx * 2, prt_low);
}

/*
 * Build redirection entry prt_idx from its fields and write it.
 *
 * @vector        vector (bits 0-7), used for FIXED delivery
 * @dest          destination APIC id
 * @delivery_mode IOAPIC_PRT_DELIVERY_MODE_* (3 bits)
 * @dest_mode     IOAPIC_PRT_DESTMODE_* (1 bit)
 * @trig_mode     IOAPIC_PRT_TRIGMODE_* (1 bit)
 * @level         pin polarity. NOTE(review): never encoded into the entry,
 *                so it is currently ignored
 * @int_mask      bit 0 set = entry masked
 *
 * An unmasked FIXED entry whose vector lies in the CPU exception range
 * (0-31) is refused and nothing is written. A masked entry may carry any
 * vector, since it cannot be delivered.
 */
void IoAPICSetPRT(ioapic_t *ioapic, uint8_t prt_idx, uint8_t vector, uint8_t dest, uint8_t delivery_mode, uint8_t dest_mode, uint8_t trig_mode, uint8_t level, uint8_t int_mask)
{
    uint64_t prt = (((uint64_t)delivery_mode & 0x7) << IOAPIC_PRT_DELIVERY_MODE_SHIFT) |
                   (((uint64_t)int_mask & 0x1) << IOAPIC_PRT_INTMASK_SHIFT) |
                   (((uint64_t)dest_mode & 0x1) << IOAPIC_PRT_DESTMODE_SHIFT) |
                   (((uint64_t)dest) << IOAPIC_PRT_DESTINATION_SHIFT) |
                   (((uint64_t)trig_mode & 0x1) << IOAPIC_PRT_TRIGMODE_SHIFT);

    (void)level;

    if (delivery_mode == IOAPIC_PRT_DELIVERY_MODE_FIXED)
    {
        // vectors 0-31 are CPU exceptions; 32 is the first usable one
        if (vector < 32 && !(int_mask & 0x1))
            return;
        prt |= vector;
    }

    if (delivery_mode == IOAPIC_PRT_DELIVERY_MODE_NMI)
    {
        // the IOAPIC ignores the vector field for NMI delivery
        prt |= 2;
    }

    // ExtINT: the vector field stays 0

    IoAPICWritePRTReg(ioapic, prt_idx, prt);
}

/*
 * Look up a discovered IOAPIC by its APIC id.
 * Returns the matching descriptor on ioapic_list_head, or NULL if none matches.
 */
ioapic_t *IoAPICFind(uint8_t id)
{
    ioapic_t *node = NULL;

    list_traversal_all_owner_to_next(node, &ioapic_list_head, list)
    {
        if (id == node->ioapic_id)
        {
            return node;
        }
    }

    return NULL;
}

/*
 * Mask redirection entry prt_idx. The IOAPIC does not deliver an interrupt
 * whose entry has the mask bit set. All other fields are preserved.
 */
void IoAPICPRTDisable(ioapic_t *ioapic, int prt_idx)
{
    uint64_t prt = IoAPICReadPRTReg(ioapic, prt_idx);
    prt |= (uint64_t)1 << IOAPIC_PRT_INTMASK_SHIFT;
    IoAPICWritePRTReg(ioapic, prt_idx, prt);
}

/*
 * Unmask redirection entry prt_idx (clear the mask bit) so the interrupt can
 * be delivered. All other fields are preserved. The function name keeps its
 * existing spelling for callers.
 */
void IoAPICPRTEnalbe(ioapic_t *ioapic, int prt_idx)
{
    uint64_t prt = IoAPICReadPRTReg(ioapic, prt_idx);
    prt &= ~((uint64_t)1 << IOAPIC_PRT_INTMASK_SHIFT);
    IoAPICWritePRTReg(ioapic, prt_idx, prt);
}

/*
 * Replace the vector (bits 0-7) of redirection entry prt_idx with vec,
 * leaving every other field intact. Passing 0 clears the vector.
 */
void IoAPICPRTSetVec(ioapic_t *ioapic, int prt_idx, uint8_t vec)
{
    uint64_t prt = IoAPICReadPRTReg(ioapic, prt_idx);
    prt &= ~(uint64_t)0xff; // drop the old vector; OR-ing alone could never lower it
    prt |= vec;
    IoAPICWritePRTReg(ioapic, prt_idx, prt);
}

/*
 * Apply IoAPICPRTDisable to every redirection entry this IOAPIC implements.
 * Only indices below ioapic->prt_num are touched, because entries past that
 * do not exist on the chip.
 */
void IoAPICDisableAll(ioapic_t *ioapic)
{
    for (int i = 0; i < ioapic->prt_num; i++)
    {
        IoAPICPRTDisable(ioapic, i);
    }
}

/*
 * Apply IoAPICPRTEnalbe to every redirection entry this IOAPIC implements.
 * Only indices below ioapic->prt_num are touched, because entries past that
 * do not exist on the chip.
 */
void IoAPICEnableAll(ioapic_t *ioapic)
{
    for (int i = 0; i < ioapic->prt_num; i++)
    {
        IoAPICPRTEnalbe(ioapic, i);
    }
}

/*
 * Program every implemented redirection entry as masked, fixed delivery,
 * physical destination IOAPIC_DEFAULT_INT_DISPATCH_DEST, edge-triggered,
 * polarity IOAPIC_PRT_INTPINPOL_HIGH, vector 0.
 *
 * The last IoAPICSetPRT argument is the mask flag (IOAPIC_PRT_INT_MASK),
 * not the mask bit's position.
 * NOTE(review): IoAPICSetPRT may refuse FIXED entries with a vector below 32;
 * confirm these masked vector-0 entries are actually written.
 */
void IoAPICPRTInit(ioapic_t *ioapic)
{
    int i;
    for (i = 0; i < ioapic->prt_num; i++)
    {
        IoAPICSetPRT(ioapic, i, 0, IOAPIC_DEFAULT_INT_DISPATCH_DEST, IOAPIC_PRT_DELIVERY_MODE_FIXED, IOAPIC_PRT_DESTMODE_PHYSICAL, IOAPIC_PRT_TRIGMODE_EDGE, IOAPIC_PRT_INTPINPOL_HIGH, IOAPIC_PRT_INT_MASK);
    }
}

/*
 * Rebuild irq2pin_map for one IOAPIC. The table is cleared first. Then, for
 * each pin below IOAPIC_PRT_NUM_MAX, the pin is bound to irq
 * gsi_base + pin and that irq is registered via IoAPICIrqRegister().
 *
 * NOTE(review): Irq2PinMapBind stores at slot irq_alloc, while
 * IoAPICIrqRegister looks entries up by irq number. The two only agree when
 * gsi_base is 0. Confirm the intended indexing.
 */
void Irq2PinMapInit(ioapic_t *ioapic)
{
    // clear the whole irq2pin table (memset takes dst, value, size)
    memset(irq2pin_map, 0, sizeof(irq2pin_map));
    // the table is empty again, so binding restarts at slot 0
    irq_alloc = 0;

    for (int pin = 0; pin < IOAPIC_PRT_NUM_MAX; pin++)
    {
        uint8_t irq = IoAPICGetGSI(ioapic, pin); // translate to irq
        // bind irq and pin
        Irq2PinMapBind(pin, irq);
        // register a irq descript
        IoAPICIrqRegister(ioapic, irq);
    }
}

/*
 * Append a {pin, irq} pair at the next free irq2pin_map slot (irq_alloc) and
 * advance the slot. When the table is full, the pair is reported and dropped
 * instead of being written past the end of the array.
 */
void Irq2PinMapBind(uint8_t pin, uint8_t irq)
{
    if (irq_alloc >= sizeof(irq2pin_map) / sizeof(irq2pin_map[0]))
    {
        KPrint("[ioapic] irq2pin map full, drop irq %d pin %d\n", (uint32_t)irq, (uint32_t)pin);
        return;
    }

    irq2pin_map[irq_alloc].pin = pin;
    irq2pin_map[irq_alloc].irq = irq;
    irq_alloc++;
}

/*
 * Translate an IOAPIC input pin to its global system interrupt number
 * (gsi_base + pin). The result is truncated to 8 bits by the uint8_t
 * return type.
 */
uint8_t IoAPICGetGSI(ioapic_t *ioapic, uint8_t pin)
{
    return (uint8_t)((uint8_t)ioapic->gsi_base + pin);
}

/*
 * Create a per-irq entry stub: copy IRQ_GENERAL_SIZE bytes of the generic
 * IrqEntryGeneral code into freshly allocated memory and patch the irq
 * number into byte IRQ_GENERAL_IRQNO_OFFSET of the copy.
 *
 * Returns the new stub, or NULL if the allocation fails.
 * NOTE(review): offset 10 presumably addresses an irq-number operand inside
 * IrqEntryGeneral; confirm it against that stub's layout.
 */
static void *IoAPICAllocIrqDes(uint8_t irq)
{
    enum { IRQ_GENERAL_IRQNO_OFFSET = 10 };

    uint8_t *func = (uint8_t *)KMemAlloc(IRQ_GENERAL_SIZE);
    if (!func)
        return NULL;

    // copy a general function
    memcpy(func, IrqEntryGeneral, IRQ_GENERAL_SIZE);

    // set irq number
    func[IRQ_GENERAL_IRQNO_OFFSET] = irq;

    return func;
}

/*
 * Give a non-ISA irq a vector and an entry stub, then point its IOAPIC pin
 * (irq2pin_map[irq].pin) at that vector.
 *
 * Irqs up to IRQ15_HARDDISK2 are ISA irqs and are already mapped, so they
 * are skipped. Irqs that do not fit in irq2pin_map or irqfunc are refused.
 * If no vector (0) or no stub can be obtained, nothing is installed.
 */
void IoAPICIrqRegister(ioapic_t *ioapic, uint8_t irq)
{
    // ISA device has been map
    if (irq <= IRQ15_HARDDISK2)
        return;

    // irq indexes both irq2pin_map and irqfunc
    if (irq >= sizeof(irq2pin_map) / sizeof(irq2pin_map[0]) ||
        irq >= sizeof(irqfunc) / sizeof(irqfunc[0]))
    {
        KPrint("[ioapic] irq %d out of range\n", (uint32_t)irq);
        return;
    }

    // 0 means no vector was available. Installing a gate there would
    // replace the CPU's vector-0 (divide error) exception handler.
    uint8_t vec = IntVectorAlloc();
    if (vec == 0)
    {
        KPrint("[ioapic] no free vector for irq %d\n", (uint32_t)irq);
        return;
    }

    // make irq generial handler
    void *handler = IoAPICAllocIrqDes(irq);
    if (!handler)
    {
        // NOTE(review): vec stays allocated; no vector-release API is visible here
        KPrint("[ioapic] no memory for irq %d handler\n", (uint32_t)irq);
        return;
    }

    irqfunc[irq] = handler;
    irq2pin_map[irq].vector = vec;

    // install the gate before routing the pin, so the interrupt never
    // arrives at an empty descriptor
    SetGateDescript(vec, handler);
    IoAPICPRTSetVec(ioapic, irq2pin_map[irq].pin, vec);
}

/*
 * Undo IoAPICIrqRegister for a non-ISA irq: clear the vector in its IOAPIC
 * pin, call SetGateDescript for the old vector, and forget the vector.
 *
 * ISA irqs (<= IRQ15_HARDDISK2), out-of-range irqs, and irqs without a
 * recorded vector are left alone. An irq without a vector is left alone so
 * that the gate for vector 0 is never touched.
 * NOTE(review): SetGateDescript is called with the same irqfunc[irq] stub,
 * which re-installs the handler rather than removing it. Confirm the intended
 * gate-removal call.
 */
void IoAPICIrqUnReigster(ioapic_t *ioapic, uint8_t irq)
{
    // ISA device has been map
    if (irq <= IRQ15_HARDDISK2)
        return;

    if (irq >= sizeof(irq2pin_map) / sizeof(irq2pin_map[0]) ||
        irq >= sizeof(irqfunc) / sizeof(irqfunc[0]))
        return;

    uint8_t vec = irq2pin_map[irq].vector;
    if (vec == 0) // never registered
        return;

    // free vector
    IoAPICPRTSetVec(ioapic, irq2pin_map[irq].pin, 0);

    // unregister irq descript
    SetGateDescript(vec, irqfunc[irq]);

    // clear vector
    irq2pin_map[irq].vector = 0;
}

/*
 * Write 0 to the IOAPIC's EOI register (offset IOAPIC_REG_EOF).
 *
 * The store is volatile MMIO, so the compiler cannot elide it.
 * NOTE(review): on IOAPICs that implement an EOI register, the value written
 * selects the vector being acknowledged; writing 0 may not acknowledge the
 * intended interrupt. Confirm against the target hardware.
 */
void IoAPICWriteEOF(ioapic_t *ioapic)
{
    volatile uint32_t *eoi = (volatile uint32_t *)(ioapic->iobase + IOAPIC_REG_EOF);

    *eoi = 0;
}

/*
 * Signal end-of-interrupt for irq to the IOAPIC that owns it. The owner is
 * looked up by id = irq / IOAPIC_PRT_NUM_MAX. Does nothing if no such
 * IOAPIC was discovered.
 */
void IoAPICSendEOF(irqno_t irq)
{
    ioapic_t *chip = IoAPICFind(irq / IOAPIC_PRT_NUM_MAX);

    // IoAPICWriteEOF dereferences the chip
    if (!chip)
        return;

    IoAPICWriteEOF(chip);
}
